{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mf3kOv1YMB5y"
      },
      "source": [
        "Copyright 2021 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-rOdskBSMfQN"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wB8E6bWergh9"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from colabtools import adhoc_import\n",
        "from contextlib import ExitStack\n",
        "\n",
        "ADHOC = True\n",
        "CLIENT = 'fig-export-fig_tree-change-451-3e0a679e9746'\n",
        "\n",
        "import tensorflow_probability.substrates.jax as tfp\n",
        "from fun_mc import using_jax as fun_mc\n",
        "\n",
        "tfd = tfp.distributions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qwbxRbDmFzOQ"
      },
      "source": [
        "# Variance analysis"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SHi-hjl2vtlb"
      },
      "outputs": [],
      "source": [
        "n_chains = 10240\n",
        "n_super_chains = 4\n",
        "n_steps = 100\n",
        "n_sub_chains = n_chains // n_super_chains\n",
        "# Number of independent synthetic replications of the experiment.\n",
        "n_trials = 200\n",
        "# True: nested ESS = overall variance / variance of the super-chain means.\n",
        "# False: use the within/between decomposition from Charles's notebook.\n",
        "use_direct_nested_ess = True\n",
        "\n",
        "# ESS by looking at raw chains\n",
        "ess_vals = []\n",
        "# ESS computed by looking at super chains\n",
        "pooled_ess_vals = []\n",
        "# Like pooled ESS, but we account for the number of sub chains used\n",
        "nested_ess_vals = []\n",
        "\n",
        "for seed in jax.random.split(jax.random.PRNGKey(0), n_trials):\n",
        "  # Each trial draws IID standard normals, one column per chain.\n",
        "  chain = jax.random.normal(seed, [n_steps, n_chains])\n",
        "  pooled_chain = chain.reshape([n_steps, n_sub_chains, n_super_chains])\n",
        "\n",
        "  between = chain.mean(0).var(0, ddof=1)\n",
        "  overall = chain.var((0, 1), ddof=1)\n",
        "  ess_vals.append(overall / between)\n",
        "\n",
        "  # A super chain averages its sub chains (axis 1) at every step.\n",
        "  super_chain = pooled_chain.mean(1)\n",
        "  pooled_between = super_chain.mean(0).var(0, ddof=1)\n",
        "  pooled_overall = super_chain.var((0, 1), ddof=1)\n",
        "  pooled_ess_vals.append(pooled_overall / pooled_between)\n",
        "\n",
        "  if use_direct_nested_ess:\n",
        "    nested_between = pooled_chain.mean((0, 1)).var(0, ddof=1)\n",
        "    nested_overall = pooled_chain.var((0, 1, 2), ddof=1)\n",
        "    nested_ess_vals.append(nested_overall / nested_between)\n",
        "  else:\n",
        "    # Calculation from Charles's notebook.\n",
        "    mean_chain = pooled_chain.mean(0)\n",
        "    mean_super_chain = pooled_chain.mean((0, 1))\n",
        "    variance_chain = pooled_chain.var(0, ddof=1)\n",
        "    variance_nested_chain = mean_chain.var(0, ddof=1) + variance_chain.mean(0)\n",
        "\n",
        "    within_var = variance_nested_chain.mean(0)\n",
        "    between_var = mean_super_chain.var(0, ddof=1)\n",
        "\n",
        "    nested_ess_vals.append(1 + within_var / between_var)\n",
        "\n",
        "ess_vals = jnp.array(ess_vals)\n",
        "pooled_ess_vals = jnp.array(pooled_ess_vals)\n",
        "nested_ess_vals = jnp.array(nested_ess_vals)\n",
        "# We can also normalize the nested ESS values to take into account that super\n",
        "# chains are larger than regular chains. This interprets the pooling as an\n",
        "# ostensibly denoised ESS estimator.\n",
        "nested_normalized_ess_vals = nested_ess_vals / n_sub_chains\n",
        "\n",
        "# These are actually rhat - 1\n",
        "rhat_vals = 1 / ess_vals\n",
        "pooled_rhat_vals = 1 / pooled_ess_vals\n",
        "nested_rhat_vals = 1 / nested_ess_vals\n",
        "# It's not clear what the meaning of this is.\n",
        "nested_normalized_rhat_vals = 1 / nested_normalized_ess_vals\n",
        "\n",
        "# One entry per estimator; these tables drive both the printed summary and\n",
        "# the histogram grid, so the two always list the same quantities in the same\n",
        "# order.\n",
        "ess_estimates = {\n",
        "    'ESS': ess_vals,\n",
        "    'pooled ESS': pooled_ess_vals,\n",
        "    'nested ESS': nested_ess_vals,\n",
        "    'nested normalized ESS': nested_normalized_ess_vals,\n",
        "}\n",
        "rhat_estimates = {\n",
        "    'rhat - 1': rhat_vals,\n",
        "    'pooled rhat - 1': pooled_rhat_vals,\n",
        "    'nested rhat - 1': nested_rhat_vals,\n",
        "    'nested normalized rhat - 1': nested_normalized_rhat_vals,\n",
        "}\n",
        "\n",
        "# Expected per-chain ESS is n_steps\n",
        "for name, vals in ess_estimates.items():\n",
        "  print(f'{name} mean + std:', vals.mean(), vals.std())\n",
        "print()\n",
        "for name, vals in rhat_estimates.items():\n",
        "  print(f'{name} mean + std:', vals.mean(), vals.std())\n",
        "\n",
        "# Top row: ESS estimators; bottom row: the matching rhat - 1 values.\n",
        "fig, axes = plt.subplots(2, 4, figsize=(24, 6))\n",
        "panels = list(ess_estimates.items()) + list(rhat_estimates.items())\n",
        "for ax, (name, vals) in zip(axes.ravel(), panels):\n",
        "  ax.set_title(f'log10 {name}')\n",
        "  ax.hist(jnp.log10(vals), histtype='step', density=True, bins=50)\n",
        "\n",
        "fig.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7CzwZYfZF1jM"
      },
      "source": [
        "# MCMC test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1j78wa-TVW3K"
      },
      "outputs": [],
      "source": [
        "dist = tfd.Normal(0., 1.)\n",
        "n_chains = 10240\n",
        "n_super_chains = 8\n",
        "# Number of HMC transitions traced per chain.\n",
        "n_steps = 10000\n",
        "n_sub_chains = n_chains // n_super_chains\n",
        "\n",
        "\n",
        "def target_log_prob_fn(x):\n",
        "  \"\"\"Returns `dist.log_prob(x)` paired with an empty extras tuple.\"\"\"\n",
        "  return dist.log_prob(x), ()\n",
        "\n",
        "\n",
        "def kernel(hmc_state, seed):\n",
        "  \"\"\"Advances the sampler by one HMC transition.\n",
        "\n",
        "  Args:\n",
        "    hmc_state: HMC state, as produced by `fun_mc.hamiltonian_monte_carlo_init`\n",
        "      or by a previous call to this function.\n",
        "    seed: PRNG key. It is split on every call so the HMC step and the next\n",
        "      iteration each receive their own key.\n",
        "\n",
        "  Returns:\n",
        "    A pair `((hmc_state, seed), (position, is_accepted))`: the updated state\n",
        "    with the leftover key, followed by the values collected by `fun_mc.trace`.\n",
        "  \"\"\"\n",
        "  hmc_seed, seed = jax.random.split(seed)\n",
        "  hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step(\n",
        "      hmc_state,\n",
        "      target_log_prob_fn=target_log_prob_fn,\n",
        "      step_size=0.5,\n",
        "      num_integrator_steps=1,\n",
        "      seed=hmc_seed)\n",
        "  return (hmc_state, seed), (hmc_state.state, hmc_extra.is_accepted)\n",
        "\n",
        "\n",
        "# A JAX key must be consumed only once, so each experiment derives one key\n",
        "# for its initial-state draw and a separate key for the sampler loop.\n",
        "init_seed, chain_seed = jax.random.split(jax.random.PRNGKey(0))\n",
        "init_seed2, chain_seed2 = jax.random.split(jax.random.PRNGKey(3))\n",
        "\n",
        "# Regular chains: every chain gets its own initial draw.\n",
        "init_x = dist.sample([n_chains], seed=init_seed)\n",
        "\n",
        "_, (chain, is_accepted) = fun_mc.trace(\n",
        "    (fun_mc.hamiltonian_monte_carlo_init(init_x, target_log_prob_fn),\n",
        "     chain_seed),\n",
        "    kernel,\n",
        "    n_steps)\n",
        "\n",
        "# Nested chains: one draw per super chain, repeated so that all of its sub\n",
        "# chains start from the same point. Shape: [n_super_chains, n_sub_chains].\n",
        "init_x2 = dist.sample([n_super_chains], seed=init_seed2)\n",
        "init_x2 = jnp.repeat(init_x2, n_sub_chains)\n",
        "init_x2 = init_x2.reshape([n_super_chains, n_sub_chains])\n",
        "\n",
        "_, (chain2, is_accepted2) = fun_mc.trace(\n",
        "    (fun_mc.hamiltonian_monte_carlo_init(init_x2, target_log_prob_fn),\n",
        "     chain_seed2),\n",
        "    kernel,\n",
        "    n_steps)\n",
        "\n",
        "# Prepend the initial states so that index 0 of each trace is the start point.\n",
        "chain = jnp.concatenate([init_x[jnp.newaxis], chain], 0)\n",
        "chain2 = jnp.concatenate([init_x2[jnp.newaxis], chain2], 0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8udM_46EsKaC"
      },
      "outputs": [],
      "source": [
        "# Traces of the first four regular chains; axis 0 of `chain` is the step.\n",
        "plt.plot(chain[:, :4])\n",
        "plt.title('regular chains: first 4 traces')\n",
        "plt.xlabel('step')\n",
        "plt.ylabel('x');"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QoMhF_n0tSQ2"
      },
      "outputs": [],
      "source": [
        "# Traces of the first four sub chains belonging to super chain 0.\n",
        "plt.plot(chain2[:, 0, :4])\n",
        "plt.title('super chain 0: first 4 sub-chain traces')\n",
        "plt.xlabel('step')\n",
        "plt.ylabel('x');"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "krTTXX9lziZV"
      },
      "outputs": [],
      "source": [
        "# Variance of the starting points (index 0 of each trace): across the\n",
        "# super-chain means for the nested run, across all chains for the regular run.\n",
        "super_init_var = chain2[0].mean(-1).var(0)\n",
        "regular_init_var = chain[0].var(0)\n",
        "super_init_var, regular_init_var"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I1oq9lixtXC1"
      },
      "outputs": [],
      "source": [
        "def _between_chain_variance(trace):\n",
        "  \"\"\"Variance over axis 1 of the running mean taken along axis 0.\n",
        "\n",
        "  Args:\n",
        "    trace: 2-D array indexed as [step, chain].\n",
        "\n",
        "  Returns:\n",
        "    1-D array whose entry `t` is the variance across chains of the mean of\n",
        "    the first `t + 1` steps.\n",
        "  \"\"\"\n",
        "  step_counts = jnp.arange(1, trace.shape[0] + 1)[:, jnp.newaxis]\n",
        "  running_mean = jnp.cumsum(trace, 0) / step_counts\n",
        "  return running_mean.var(1)\n",
        "\n",
        "\n",
        "# An alternative baseline would treat every sub chain of `chain2` as a regular\n",
        "# chain, i.e. take the variance over both of its chain axes.\n",
        "between_reg = _between_chain_variance(chain)\n",
        "# Each super chain is the per-step average over its sub chains (last axis).\n",
        "super_chain = chain2.mean(-1)\n",
        "between_nested = _between_chain_variance(super_chain)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "irascTzltwiv"
      },
      "outputs": [],
      "source": [
        "# Between-chain variance as a function of chain length, on log-log axes.\n",
        "ax = plt.gca()\n",
        "ax.set_title('between chain variance')\n",
        "variance_curves = [\n",
        "    (between_reg, 'regular chain'),\n",
        "    (between_nested, 'super chain'),\n",
        "    (between_reg / n_sub_chains, 'regular chain / n_sub_chains'),\n",
        "]\n",
        "for curve, curve_label in variance_curves:\n",
        "  ax.plot(curve, label=curve_label)\n",
        "# Dashed horizontal reference line at 1e-2.\n",
        "ax.axhline(1e-2, ls='--', color='black', lw=2)\n",
        "ax.set_xscale('log')\n",
        "ax.set_yscale('log')\n",
        "ax.set_xlabel('chain length')\n",
        "ax.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C5LEwbtr_UHG"
      },
      "outputs": [],
      "source": [
        "# Keep the current results under `*2` names for the comparison plot below.\n",
        "# NOTE(review): on a straight Run All these are the same values as the\n",
        "# originals, so the comparison overlays identical curves. Presumably the\n",
        "# cells above are meant to be re-run with a different `n_super_chains` after\n",
        "# this snapshot is taken -- confirm.\n",
        "between_reg2, between_nested2, n_sub_chains2 = (\n",
        "    between_reg, between_nested, n_sub_chains)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y2Fp5rLNTxkZ"
      },
      "outputs": [],
      "source": [
        ""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cxG6z3UuSwW4"
      },
      "outputs": [],
      "source": [
        "# Sub-chain counts of the current run and of the snapshot run.\n",
        "for sub_chain_count in (n_sub_chains, n_sub_chains2):\n",
        "  print(sub_chain_count)\n",
        "\n",
        "plt.figure(figsize=(12, 8))\n",
        "plt.title('between chain variance')\n",
        "\n",
        "# Current run first, then the snapshot stored under the `*2` names.\n",
        "comparison_curves = [\n",
        "    (between_reg, 'regular chain'),\n",
        "    (between_nested, 'super chain'),\n",
        "    (between_reg / n_sub_chains, 'regular chain / n_sub_chains'),\n",
        "    (between_reg2, 'regular chain 2'),\n",
        "    (between_nested2, 'super chain 2'),\n",
        "    (between_reg2 / n_sub_chains2, 'regular chain 2 / n_sub_chains 2'),\n",
        "]\n",
        "for curve, curve_label in comparison_curves:\n",
        "  plt.plot(curve, label=curve_label)\n",
        "\n",
        "# Dashed horizontal reference line at 1e-2.\n",
        "plt.axhline(1e-2, ls='--', color='black', lw=2)\n",
        "plt.xscale('log')\n",
        "plt.yscale('log')\n",
        "plt.xlabel('chain length')\n",
        "plt.legend()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "Nested R-hat.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "199pnTb5NJtFaLozNAlp3todlh4yjHD6g",
          "timestamp": 1632344135035
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
